{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "grBmytrShbUE"
      },
      "source": [
        "# High-performance simulations with TFF\n",
        "\n",
        "This tutorial will describe how to setup high-performance simulations with TFF\n",
        "in a variety of common scenarios.\n",
        "\n",
        "NOTE: The mechanisms covered here are not included in the latest release, have\n",
        "not been tested yet, and the API may evolve. In order to follow this tutorial,\n",
        "you will need to build a TFF pip package from scratch from the latest sources, and install it in a Jupyter notebook with a Python 3 runtime. The new executor\n",
        "stack is not compatible with Python 2.\n",
        "\n",
        "TODO(b/134543154): Populate the content, some of the things to cover here:\n",
        "- using GPUs in a single-machine setup,\n",
        "- multi-machine setup on GCP/GKE, with and without TPUs,\n",
        "- interfacing MapReduce-like backends,\n",
        "- current limitations and when/how they will be relaxed."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yiq_MY4LopET"
      },
      "source": [
        "## Before we begin\n",
        "\n",
        "First, make sure your notebook is connected to a backend that has the relevant\n",
        "components (including gRPC dependencies for multi-machine scenarios) compiled.\n",
        "\n",
        "Now, if you are running this notebook in Jupyter, you may need to take an extra\n",
        "step to work around the\n",
        "[limitations of Jupyter with asyncio](https://github.com/jupyter/notebook/issues/3397#issuecomment-419386811)\n",
        "by installing the [nest_asyncio](https://github.com/erdewit/nest_asyncio)\n",
        "package."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tnngjncsPq15"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade nest_asyncio\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_zFenI3IPpgI"
      },
      "source": [
        "Now, let's start by loading the MNIST example from the TFF website, and\n",
        "declaring the Python function that will run a small experiment loop over\n",
        "a group of 10 clients."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ke7EyuvG0Zyn"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade tensorflow_federated\n",
        "!pip install --upgrade tf-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2dVPgxN0MdG2"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "import warnings\n",
        "import time\n",
        "\n",
        "import tensorflow as tf\n",
        "tf.compat.v1.enable_v2_behavior()\n",
        "\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "warnings.simplefilter('ignore')\n",
        "\n",
        "source, _ = tff.simulation.datasets.emnist.load_data()\n",
        "\n",
        "def map_fn(example):\n",
        "  return {'x': tf.reshape(example['pixels'], [-1]), 'y': example['label']}\n",
        "\n",
        "def client_data(n):\n",
        "  ds = source.create_tf_dataset_for_client(source.client_ids[n])\n",
        "  return ds.repeat(10).map(map_fn).shuffle(500).batch(20)\n",
        "\n",
        "train_data = [client_data(n) for n in range(10)]\n",
        "\n",
        "batch = tf.nest.map_structure(lambda x: x.numpy(), iter(train_data[0]).next())\n",
        "\n",
        "def model_fn():\n",
        "  model = tf.keras.models.Sequential([\n",
        "      tf.keras.layers.Flatten(input_shape=(784,)),\n",
        "      tf.keras.layers.Dense(10, tf.nn.softmax, kernel_initializer='zeros')\n",
        "  ])\n",
        "  model.compile(\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "      optimizer=tf.keras.optimizers.SGD(0.02),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n",
        "  return tff.learning.from_compiled_keras_model(model, batch)\n",
        "\n",
        "trainer = tff.learning.build_federated_averaging_process(model_fn)\n",
        "\n",
        "def evaluate(num_rounds=10):\n",
        "  state = trainer.initialize()\n",
        "  for _ in range(num_rounds):\n",
        "    t1 = time.time()\n",
        "    state, metrics = trainer.next(state, train_data)\n",
        "    t2 = time.time()\n",
        "    print('loss {}, round time {}'.format(metrics.loss, t2 - t1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CDHJF7EIiEy-"
      },
      "source": [
        "## Single-machine simulations\n",
        "\n",
        "A simple local multi-threaded executor can be created using a new currently\n",
        "undocumented framework function `tff.framework.create_local_executor()`, and\n",
        "made default by calling `tff.framework.set_default_executor()`, as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 185
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 326915,
          "status": "ok",
          "timestamp": 1568008145631,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 420
        },
        "id": "-V6uCS_BMoR9",
        "outputId": "dc9a34df-96f1-41da-a9ca-2328cc656506"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "loss 2.9510040283203125, round time 49.65723657608032\n",
            "loss 2.777134656906128, round time 45.5357563495636\n",
            "loss 2.5103652477264404, round time 29.720882892608643\n",
            "loss 2.2921206951141357, round time 30.4314706325531\n",
            "loss 2.0617873668670654, round time 32.21593737602234\n",
            "loss 1.9325430393218994, round time 43.6105010509491\n",
            "loss 1.7762397527694702, round time 23.19011878967285\n",
            "loss 1.6028356552124023, round time 25.11474061012268\n",
            "loss 1.5010586977005005, round time 24.695493936538696\n",
            "loss 1.4369142055511475, round time 22.34806251525879\n"
          ]
        }
      ],
      "source": [
        "tff.framework.set_default_executor(tff.framework.create_local_executor())\n",
        "\n",
        "evaluate()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6NnVUd6qM6h-"
      },
      "source": [
        "The reference executor can be automatically installed back by calling\n",
        "the `tff.framework.set_default_executor()` function without an argument."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "W9GVbYCKkIYU"
      },
      "outputs": [],
      "source": [
        "tff.framework.set_default_executor()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bZ171NhcNa3M"
      },
      "source": [
        "## Multi-machine simulations on GCP/GKE, GPUs, TPUs, and beyond...\n",
        "\n",
        "Coming very soon."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "High-performance simulations with TFF",
      "provenance": [
        {
          "file_id": "14vSn6H8hu35BMb48b48hHYJ3Lln3OTQL",
          "timestamp": 1561680139142
        }
      ],
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
